import torch
import torchvision.datasets as dsets
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import pymc as pm

# Reconstruction Error
# It's natural to seek shortcuts.

# Job Satisfaction Data
import numpy as np

# Per-item standard deviations for the 12 survey variables.
stds = np.array([0.939, 1.017, 0.937, 0.562, 0.760, 0.524,
                 0.585, 0.609, 0.731, 0.711, 1.124, 1.001])
n = len(stds)

# Published lower-triangular correlations, flattened row by row.
corr_values = [
    1.000,
    .668, 1.000,
    .635, .599, 1.000,
    .263, .261, .164, 1.000,
    .290, .315, .247, .486, 1.000,
    .207, .245, .231, .251, .449, 1.000,
    -.206, -.182, -.195, -.309, -.266, -.142, 1.000,
    -.280, -.241, -.238, -.344, -.305, -.230, .753, 1.000,
    -.258, -.244, -.185, -.255, -.255, -.215, .554, .587, 1.000,
    .080, .096, .094, -.017, .151, .141, -.074, -.111, .016, 1.000,
    .061, .028, -.035, -.058, -.051, -.003, -.040, -.040, -.018, .284, 1.000,
    .113, .174, .059, .063, .138, .044, -.119, -.073, -.084, .563, .379, 1.000
]

# Rebuild the full symmetric correlation matrix: np.tril_indices walks the
# lower triangle in the same row-major order as the flat list above.
corr_matrix = np.zeros((n, n))
corr_matrix[np.tril_indices(n)] = corr_values
corr_matrix = corr_matrix + np.tril(corr_matrix, k=-1).T

# Covariance: Sigma = D R D, written element-wise as outer(s, s) * R.
cov_matrix = np.outer(stds, stds) * corr_matrix
#cov_matrix_test = np.dot(np.dot(np.diag(stds), corr_matrix), np.diag(stds))

columns = ["JW1", "JW2", "JW3", "UF1", "UF2", "FOR",
           "DA1", "DA2", "DA3", "EBA", "ST", "MI"]
corr_df = pd.DataFrame(corr_matrix, columns=columns)
cov_df = pd.DataFrame(cov_matrix, columns=columns)
cov_df
def make_sample(cov_matrix, size, columns):
    """Draw `size` observations from a zero-mean multivariate normal.

    Parameters
    ----------
    cov_matrix : (p, p) array-like covariance matrix.
    size : int, number of rows to draw.
    columns : sequence of p column names for the returned DataFrame.

    Returns
    -------
    pd.DataFrame of shape (size, p).
    """
    # Size the mean vector from the covariance matrix itself (the mean was
    # previously hard-coded as [0]*12, tying the helper to 12 features).
    mean = np.zeros(np.asarray(cov_matrix).shape[0])
    draws = np.random.multivariate_normal(mean, cov_matrix, size=size)
    return pd.DataFrame(draws, columns=columns)
# Simulate a sample matching the published study's n = 263 observations.
sample_df = make_sample(cov_matrix, 263, columns)
sample_df.head()
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.330438 | 0.326503 | 0.583599 | 0.298558 | 1.112858 | 0.371895 | 0.247663 | 0.704001 | 0.739605 | 0.317825 | 1.546412 | 1.282066 |
| 1 | 0.000730 | -0.598124 | -0.882404 | 0.034036 | -0.220583 | -0.443250 | 0.452083 | 0.976292 | 1.460018 | 0.208454 | -0.237027 | 0.124897 |
| 2 | -0.334853 | -0.171359 | -0.862147 | -0.696685 | 0.294389 | -0.671320 | 0.023049 | -0.117460 | 0.394511 | 0.769453 | 1.138158 | 0.216388 |
| 3 | -0.765717 | 0.554349 | 0.062522 | 0.181065 | 0.609657 | 0.595781 | -0.056995 | -0.635932 | -0.330862 | -0.424084 | -0.548190 | -0.637716 |
| 4 | -0.912169 | -0.369919 | -0.210114 | 0.185822 | -0.755927 | -0.490341 | -0.472250 | -0.419797 | 0.123084 | 1.157508 | -0.840009 | -0.144719 |
# Empirical correlation matrix of the simulated sample (visualized below).
data = sample_df.corr()
def plot_heatmap(data, title="Correlation Matrix", vmin=-.2, vmax=.2, ax=None, figsize=(10, 6), colorbar=True):
    """Render a DataFrame as an annotated heatmap.

    Parameters
    ----------
    data : pd.DataFrame whose values, columns and index label the plot.
    title : axes title.
    vmin, vmax : color-scale limits.
    ax : existing Axes to draw into; a new figure is created when None.
    figsize : size of the figure created when ax is None.
    colorbar : whether to attach a colorbar.
    """
    data_matrix = data.values
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(data_matrix, cmap='viridis', vmin=vmin, vmax=vmax)
    # Annotate every cell with its value, choosing text color for contrast.
    for i in range(data_matrix.shape[0]):
        for j in range(data_matrix.shape[1]):
            ax.text(
                j, i,
                f"{data_matrix[i, j]:.2f}",
                ha="center", va="center",
                color="white" if data_matrix[i, j] < 0.5 else "black"
            )
    ax.set_title(title)
    # Fix: set tick POSITIONS before tick labels. Calling set_xticklabels /
    # set_yticklabels first triggered matplotlib's "set_ticklabels() should
    # only be used with a fixed number of ticks" UserWarning.
    ax.set_xticks(np.arange(data.shape[1]))
    ax.set_xticklabels(data.columns)
    ax.set_yticks(np.arange(data.shape[0]))
    ax.set_yticklabels(data.index)
    if colorbar:
        plt.colorbar(im)
# Full correlation range [-1, 1] for the color scale.
plot_heatmap(data, vmin=-1, vmax=1)
# Draw a fresh 100-observation sample and study low-rank SVD reconstructions.
X = make_sample(cov_matrix, 100, columns=columns)
U, S, VT = np.linalg.svd(X, full_matrices=False)

ranks = [2, 5, 12]
# Truncated reconstruction X_k = U_k diag(S_k) V_k^T for each rank k.
reconstructions = [U[:, :k] @ np.diag(S[:k]) @ VT[:k, :] for k in ranks]

# Show the original matrix next to each truncated reconstruction.
fig, axes = plt.subplots(1, len(ranks) + 1, figsize=(10, 15))
axes[0].imshow(X, cmap='viridis')
axes[0].set_title("Original")
axes[0].axis("off")
for ax, k, X_k in zip(axes[1:], ranks, reconstructions):
    ax.imshow(X_k, cmap='viridis')
    ax.set_title(f"Rank {k}")
    ax.axis("off")
plt.suptitle("Reconstruction of Data Using SVD \n various truncation options", fontsize=12, x=.5, y=1.01)
plt.tight_layout()
plt.show()
class NumericVAE(nn.Module):
    """Gaussian VAE for continuous (numeric) tabular data.

    Encoder: x -> q(z|x) = N(mu(x), sigma^2(x)).
    Decoder: z -> p(x|z) = N(mu_x(z), sigma_x^2(z)) with a per-feature variance,
    so reconstructions carry their own uncertainty.
    """

    def __init__(self, n_features, hidden_dim=64, latent_dim=8):
        super().__init__()
        # ---------- ENCODER ----------
        # Compress input features into a hidden representation.
        self.fc1 = nn.Linear(n_features, hidden_dim)
        # Latent posterior parameters q(z|x): mean and log-variance.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)       # mu(x)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)   # log(sigma^2(x))
        # ---------- DECODER ----------
        # Map latent z back into a hidden representation.
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        # Reconstruction distribution p(x|z): per-feature mean and log-variance.
        self.fc_out_mu = nn.Linear(hidden_dim, n_features)       # mu_x(z)
        self.fc_out_logvar = nn.Linear(hidden_dim, n_features)   # log(sigma_x^2(z))

    def encode(self, x):
        """Encoder pass: input x -> latent (mu, logvar)."""
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decoder pass: latent z -> reconstruction (mean, log-variance)."""
        h = F.relu(self.fc2(z))
        return self.fc_out_mu(h), self.fc_out_logvar(h)

    def forward(self, x):
        """Full pass: returns ((recon_mu, recon_logvar), mu, logvar)."""
        mu, logvar = self.encode(x)              # q(z|x)
        z = self.reparameterize(mu, logvar)      # sample z from q(z|x)
        return self.decode(z), mu, logvar        # p(x|z) parameters + posterior

    def generate(self, n_samples=100):
        """Sample synthetic rows: z ~ N(0, I) prior, then x ~ p(x|z)."""
        self.eval()
        with torch.no_grad():
            z = torch.randn(n_samples, self.fc_mu.out_features)
            cont_mu, cont_logvar = self.decode(z)
            # Sample from the reconstruction Gaussian: mu_x + sigma_x * eps.
            return cont_mu + torch.exp(0.5 * cont_logvar) * torch.randn_like(cont_mu)


def vae_loss(recon_mu, recon_logvar, x, mu, logvar):
    """ELBO-style VAE loss: Gaussian reconstruction NLL + KL(q(z|x) || N(0, I)).

    Returns (total_loss, recon_loss, kl_loss), each averaged over the batch.
    """
    # Reconstruction loss: Gaussian negative log-likelihood per entry.
    recon_var = torch.exp(recon_logvar)
    recon_nll = 0.5 * (torch.log(2 * torch.pi * recon_var) + (x - recon_mu) ** 2 / recon_var)
    recon_loss = recon_nll.sum(dim=1).mean()  # sum over features, mean over batch
    # Closed-form KL divergence between the Gaussian posterior and the prior.
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl_loss = kl_div.mean()
    return recon_loss + kl_loss, recon_loss, kl_loss


def prep_data_vae(sample_size=1000, batch_size=32):
    """Simulate a sample and return (train_loader, test_loader).

    `batch_size` is a new keyword (default 32, the previously hard-coded value)
    for consistency with `prep_data_vae_missing`.
    """
    sample_df = make_sample(cov_matrix=cov_matrix, size=sample_size, columns=columns)
    X_train, X_test = train_test_split(sample_df.values, test_size=0.2, random_state=890)
    X_train = torch.tensor(X_train, dtype=torch.float32)
    X_test = torch.tensor(X_test, dtype=torch.float32)
    train_loader = torch.utils.data.DataLoader(X_train, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(X_test, batch_size=batch_size)
    return train_loader, test_loader
# | output: false
def train_vae(vae, optimizer, train, test, patience=30, wait=10, n_epochs=1000):
    """Train `vae` with early stopping on the average test loss.

    Parameters
    ----------
    vae : model to train (modified in place).
    optimizer : torch optimizer over vae.parameters().
    train, test : DataLoaders yielding float32 feature batches.
    patience : number of epochs without improvement before stopping.
    wait : initial value of the no-improvement counter (reset to 0 whenever
        the test loss improves).
    n_epochs : maximum number of epochs.

    Returns
    -------
    (vae, pd.DataFrame) — the model (best weights restored on early stop) and
    a per-epoch history with columns train_loss / test_loss / best_loss.
    """
    best_loss = float('inf')
    best_state = vae.state_dict()  # safe fallback even if no epoch improves
    losses = []
    for epoch in range(n_epochs):
        # --- Train loop ---
        vae.train()
        train_loss = 0.0
        for batch in train:
            optimizer.zero_grad()
            (recon_mu, recon_logvar), mu, logvar = vae(batch)
            loss, recon_loss, kl_loss = vae_loss(recon_mu, recon_logvar, batch, mu, logvar)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * batch.size(0)
        avg_train_loss = train_loss / train.dataset.shape[0]
        # --- Test loop ---
        vae.eval()
        test_loss = 0.0
        with torch.no_grad():
            for batch in test:
                (recon_mu, recon_logvar), mu, logvar = vae(batch)
                loss, _, _ = vae_loss(recon_mu, recon_logvar, batch, mu, logvar)
                test_loss += loss.item() * batch.size(0)
        avg_test_loss = test_loss / test.dataset.shape[0]
        print(f"Epoch {epoch+1}/{n_epochs} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")
        # Fix: early-stop on the AVERAGED test loss. The original compared the
        # summed test loss against `best_loss`, so the recorded best_loss was on
        # a different scale from the averaged columns in the history.
        if avg_test_loss < best_loss - 1e-4:
            best_loss, wait = avg_test_loss, 0
            best_state = vae.state_dict()
        else:
            wait += 1
        # Fix: record this epoch BEFORE a possible early stop so the final
        # epoch is not dropped from the returned history.
        losses.append([avg_train_loss, avg_test_loss, best_loss])
        if wait >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            vae.load_state_dict(best_state)  # restore best weights
            break
    return vae, pd.DataFrame(losses, columns=['train_loss', 'test_loss', 'best_loss'])
# Fit VAEs on increasing sample sizes to see how data volume affects the fit.
train_500, test_500 = prep_data_vae(500)
vae = NumericVAE(n_features=train_500.dataset.shape[1], hidden_dim=64)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_500, losses_df_500 = train_vae(vae, optimizer, train_500, test_500)

train_1000, test_1000 = prep_data_vae(1000)
vae = NumericVAE(n_features=train_1000.dataset.shape[1], hidden_dim=64)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_1000, losses_df_1000 = train_vae(vae, optimizer, train_1000, test_1000)

train_10_000, test_10_000 = prep_data_vae(10_000)
vae = NumericVAE(n_features=train_10_000.dataset.shape[1], hidden_dim=64)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_10_000, losses_df_10_000 = train_vae(vae, optimizer, train_10_000, test_10_000)
Epoch 2/1000 | Train Loss: 14.6294 | Test Loss: 14.0013
Epoch 3/1000 | Train Loss: 14.4412 | Test Loss: 13.8137
Epoch 4/1000 | Train Loss: 14.2084 | Test Loss: 13.6134
Epoch 5/1000 | Train Loss: 14.0166 | Test Loss: 13.5724
Epoch 6/1000 | Train Loss: 13.8730 | Test Loss: 13.2442
Epoch 7/1000 | Train Loss: 13.7800 | Test Loss: 13.1401
Epoch 8/1000 | Train Loss: 13.6590 | Test Loss: 13.2363
Epoch 9/1000 | Train Loss: 13.5496 | Test Loss: 12.9563
Epoch 10/1000 | Train Loss: 13.4809 | Test Loss: 13.1903
Epoch 11/1000 | Train Loss: 13.4523 | Test Loss: 12.8897
Epoch 12/1000 | Train Loss: 13.4471 | Test Loss: 13.0796
Epoch 13/1000 | Train Loss: 13.3620 | Test Loss: 12.9527
Epoch 14/1000 | Train Loss: 13.3164 | Test Loss: 12.8384
Epoch 15/1000 | Train Loss: 13.4159 | Test Loss: 12.8698
Epoch 16/1000 | Train Loss: 13.3627 | Test Loss: 12.9657
Epoch 17/1000 | Train Loss: 13.3054 | Test Loss: 12.8151
Epoch 18/1000 | Train Loss: 13.2753 | Test Loss: 12.9287
Epoch 19/1000 | Train Loss: 13.2456 | Test Loss: 12.9137
Epoch 20/1000 | Train Loss: 13.0975 | Test Loss: 12.8506
Epoch 21/1000 | Train Loss: 13.0006 | Test Loss: 12.7456
Epoch 22/1000 | Train Loss: 13.0891 | Test Loss: 12.7760
Epoch 23/1000 | Train Loss: 13.0707 | Test Loss: 12.5124
Epoch 24/1000 | Train Loss: 13.0744 | Test Loss: 12.7329
Epoch 25/1000 | Train Loss: 13.0097 | Test Loss: 12.5453
Epoch 26/1000 | Train Loss: 13.1350 | Test Loss: 12.6985
Epoch 27/1000 | Train Loss: 13.0438 | Test Loss: 12.6420
Epoch 28/1000 | Train Loss: 12.9861 | Test Loss: 12.5609
Epoch 29/1000 | Train Loss: 12.9717 | Test Loss: 12.7911
Epoch 30/1000 | Train Loss: 12.9789 | Test Loss: 12.6480
Epoch 31/1000 | Train Loss: 12.9025 | Test Loss: 12.4954
Epoch 32/1000 | Train Loss: 12.8808 | Test Loss: 12.3960
Epoch 33/1000 | Train Loss: 12.6912 | Test Loss: 12.5765
Epoch 34/1000 | Train Loss: 12.7194 | Test Loss: 12.4289
Epoch 35/1000 | Train Loss: 12.7199 | Test Loss: 12.7059
Epoch 36/1000 | Train Loss: 12.7432 | Test Loss: 12.8295
Epoch 37/1000 | Train Loss: 12.5384 | Test Loss: 12.5140
Epoch 38/1000 | Train Loss: 12.7536 | Test Loss: 12.4514
Epoch 39/1000 | Train Loss: 12.5349 | Test Loss: 12.2500
Epoch 40/1000 | Train Loss: 12.5994 | Test Loss: 12.1384
Epoch 41/1000 | Train Loss: 12.5682 | Test Loss: 12.2923
Epoch 42/1000 | Train Loss: 12.4889 | Test Loss: 12.4402
Epoch 43/1000 | Train Loss: 12.5832 | Test Loss: 12.3492
Epoch 44/1000 | Train Loss: 12.6140 | Test Loss: 12.4188
Epoch 45/1000 | Train Loss: 12.3821 | Test Loss: 12.6913
Epoch 46/1000 | Train Loss: 12.4215 | Test Loss: 12.0890
Epoch 47/1000 | Train Loss: 12.3097 | Test Loss: 12.0461
Epoch 48/1000 | Train Loss: 12.2807 | Test Loss: 12.2536
Epoch 49/1000 | Train Loss: 12.4938 | Test Loss: 12.2436
Epoch 50/1000 | Train Loss: 12.3915 | Test Loss: 12.1382
Epoch 51/1000 | Train Loss: 12.3161 | Test Loss: 12.1779
Epoch 52/1000 | Train Loss: 12.3445 | Test Loss: 12.2356
Epoch 53/1000 | Train Loss: 12.3148 | Test Loss: 12.2953
Epoch 54/1000 | Train Loss: 12.2382 | Test Loss: 12.2845
Epoch 55/1000 | Train Loss: 12.2473 | Test Loss: 12.2716
Epoch 56/1000 | Train Loss: 12.1182 | Test Loss: 12.5098
Epoch 57/1000 | Train Loss: 12.4766 | Test Loss: 11.8399
Epoch 58/1000 | Train Loss: 12.1878 | Test Loss: 12.0880
Epoch 59/1000 | Train Loss: 12.2158 | Test Loss: 12.2003
Epoch 60/1000 | Train Loss: 12.1366 | Test Loss: 12.0985
Epoch 61/1000 | Train Loss: 12.2614 | Test Loss: 12.1348
Epoch 62/1000 | Train Loss: 12.2824 | Test Loss: 12.1235
Epoch 63/1000 | Train Loss: 12.2330 | Test Loss: 12.0664
Epoch 64/1000 | Train Loss: 12.2218 | Test Loss: 11.7363
Epoch 65/1000 | Train Loss: 12.1137 | Test Loss: 12.1480
Epoch 66/1000 | Train Loss: 12.1029 | Test Loss: 12.1612
Epoch 67/1000 | Train Loss: 12.0378 | Test Loss: 11.9580
Epoch 68/1000 | Train Loss: 12.2043 | Test Loss: 12.0151
Epoch 69/1000 | Train Loss: 12.0640 | Test Loss: 11.5841
Epoch 70/1000 | Train Loss: 12.0710 | Test Loss: 11.9260
Epoch 71/1000 | Train Loss: 12.0873 | Test Loss: 11.9467
Epoch 72/1000 | Train Loss: 12.2275 | Test Loss: 12.0173
Epoch 73/1000 | Train Loss: 12.0214 | Test Loss: 12.2340
Epoch 74/1000 | Train Loss: 12.1030 | Test Loss: 12.1632
Epoch 75/1000 | Train Loss: 11.9975 | Test Loss: 12.1081
Epoch 76/1000 | Train Loss: 12.0733 | Test Loss: 11.9612
Epoch 77/1000 | Train Loss: 12.0680 | Test Loss: 11.9965
Epoch 78/1000 | Train Loss: 11.9915 | Test Loss: 11.5750
Epoch 79/1000 | Train Loss: 12.1035 | Test Loss: 11.6641
Epoch 80/1000 | Train Loss: 12.0573 | Test Loss: 11.8880
Epoch 81/1000 | Train Loss: 12.0454 | Test Loss: 11.8869
Epoch 82/1000 | Train Loss: 12.2185 | Test Loss: 11.8352
Epoch 83/1000 | Train Loss: 11.8416 | Test Loss: 11.7722
Epoch 84/1000 | Train Loss: 11.9458 | Test Loss: 12.0839
Epoch 85/1000 | Train Loss: 11.9343 | Test Loss: 11.8382
Epoch 86/1000 | Train Loss: 11.7793 | Test Loss: 11.6629
Epoch 87/1000 | Train Loss: 11.9376 | Test Loss: 11.8517
Epoch 88/1000 | Train Loss: 11.9058 | Test Loss: 11.9674
Epoch 89/1000 | Train Loss: 11.9810 | Test Loss: 11.9530
Epoch 90/1000 | Train Loss: 11.9375 | Test Loss: 11.9583
Epoch 91/1000 | Train Loss: 11.9050 | Test Loss: 11.8016
Epoch 92/1000 | Train Loss: 11.9151 | Test Loss: 11.8442
Epoch 93/1000 | Train Loss: 11.9765 | Test Loss: 11.6780
Epoch 94/1000 | Train Loss: 11.7764 | Test Loss: 11.6647
Epoch 95/1000 | Train Loss: 11.9397 | Test Loss: 11.7466
Epoch 96/1000 | Train Loss: 11.8824 | Test Loss: 11.7991
Epoch 97/1000 | Train Loss: 11.8992 | Test Loss: 11.6441
Epoch 98/1000 | Train Loss: 11.9894 | Test Loss: 11.7137
Epoch 99/1000 | Train Loss: 11.8923 | Test Loss: 11.8582
Epoch 100/1000 | Train Loss: 11.9387 | Test Loss: 11.7990
Epoch 101/1000 | Train Loss: 11.8489 | Test Loss: 11.6433
Epoch 102/1000 | Train Loss: 11.8347 | Test Loss: 11.7863
Epoch 103/1000 | Train Loss: 11.8149 | Test Loss: 11.8709
Epoch 104/1000 | Train Loss: 11.8872 | Test Loss: 11.6314
Epoch 105/1000 | Train Loss: 11.9860 | Test Loss: 11.6282
Epoch 106/1000 | Train Loss: 11.7904 | Test Loss: 11.5144
Epoch 107/1000 | Train Loss: 11.9036 | Test Loss: 11.8842
Epoch 108/1000 | Train Loss: 11.8875 | Test Loss: 11.7332
Epoch 109/1000 | Train Loss: 11.9403 | Test Loss: 11.5999
Epoch 110/1000 | Train Loss: 11.7477 | Test Loss: 11.8332
Epoch 111/1000 | Train Loss: 11.7686 | Test Loss: 11.8716
Epoch 112/1000 | Train Loss: 11.8730 | Test Loss: 11.5982
Epoch 113/1000 | Train Loss: 11.8873 | Test Loss: 11.8883
Epoch 114/1000 | Train Loss: 11.8040 | Test Loss: 11.7196
Epoch 115/1000 | Train Loss: 11.8358 | Test Loss: 11.8426
Epoch 116/1000 | Train Loss: 11.8142 | Test Loss: 11.6931
Epoch 117/1000 | Train Loss: 11.8364 | Test Loss: 11.6850
Epoch 118/1000 | Train Loss: 11.7723 | Test Loss: 11.6787
Epoch 119/1000 | Train Loss: 11.8987 | Test Loss: 11.9505
Epoch 120/1000 | Train Loss: 11.8757 | Test Loss: 11.6119
Epoch 121/1000 | Train Loss: 11.8819 | Test Loss: 11.5585
Epoch 122/1000 | Train Loss: 11.8183 | Test Loss: 11.6900
Epoch 123/1000 | Train Loss: 11.8411 | Test Loss: 11.7078
Epoch 124/1000 | Train Loss: 11.8336 | Test Loss: 11.4010
Epoch 125/1000 | Train Loss: 11.8570 | Test Loss: 11.7128
Epoch 126/1000 | Train Loss: 11.7205 | Test Loss: 11.6107
Epoch 127/1000 | Train Loss: 11.8519 | Test Loss: 11.7438
Epoch 128/1000 | Train Loss: 11.7875 | Test Loss: 11.5455
Epoch 129/1000 | Train Loss: 11.6375 | Test Loss: 11.6376
Epoch 130/1000 | Train Loss: 11.8661 | Test Loss: 11.6423
Epoch 131/1000 | Train Loss: 11.7535 | Test Loss: 11.5839
Epoch 132/1000 | Train Loss: 11.9020 | Test Loss: 11.6695
Epoch 133/1000 | Train Loss: 11.7003 | Test Loss: 11.7084
Epoch 134/1000 | Train Loss: 11.7411 | Test Loss: 11.3230
Epoch 135/1000 | Train Loss: 11.7531 | Test Loss: 11.9448
Epoch 136/1000 | Train Loss: 11.9063 | Test Loss: 11.8530
Epoch 137/1000 | Train Loss: 11.6648 | Test Loss: 11.7637
Epoch 138/1000 | Train Loss: 11.8565 | Test Loss: 11.3914
Epoch 139/1000 | Train Loss: 11.6975 | Test Loss: 11.6894
Epoch 140/1000 | Train Loss: 11.7427 | Test Loss: 11.7183
Epoch 141/1000 | Train Loss: 11.8284 | Test Loss: 11.5880
Epoch 142/1000 | Train Loss: 11.8494 | Test Loss: 11.6434
Epoch 143/1000 | Train Loss: 11.7410 | Test Loss: 11.6091
Epoch 144/1000 | Train Loss: 11.8402 | Test Loss: 11.8378
Epoch 145/1000 | Train Loss: 11.7116 | Test Loss: 11.8754
Epoch 146/1000 | Train Loss: 11.8941 | Test Loss: 11.7678
Epoch 147/1000 | Train Loss: 11.7115 | Test Loss: 11.8005
Epoch 148/1000 | Train Loss: 11.7632 | Test Loss: 11.8169
Epoch 149/1000 | Train Loss: 11.7536 | Test Loss: 11.7337
Epoch 150/1000 | Train Loss: 11.8170 | Test Loss: 11.6149
Epoch 151/1000 | Train Loss: 11.8265 | Test Loss: 12.0937
Epoch 152/1000 | Train Loss: 11.7586 | Test Loss: 11.6340
Epoch 153/1000 | Train Loss: 11.7092 | Test Loss: 11.5450
Epoch 154/1000 | Train Loss: 11.8161 | Test Loss: 11.7703
Epoch 155/1000 | Train Loss: 11.6751 | Test Loss: 11.8580
Epoch 156/1000 | Train Loss: 11.7677 | Test Loss: 11.8970
Epoch 157/1000 | Train Loss: 11.7456 | Test Loss: 11.7207
Epoch 158/1000 | Train Loss: 11.8477 | Test Loss: 11.5301
Epoch 159/1000 | Train Loss: 11.7751 | Test Loss: 11.6020
Epoch 160/1000 | Train Loss: 11.6670 | Test Loss: 11.7262
Epoch 161/1000 | Train Loss: 11.6201 | Test Loss: 11.7380
Epoch 162/1000 | Train Loss: 11.5950 | Test Loss: 11.7042
Epoch 163/1000 | Train Loss: 11.8083 | Test Loss: 11.6820
Epoch 164/1000 | Train Loss: 11.8261 | Test Loss: 11.8257
Early stopping at epoch 164
Epoch 1/1000 | Train Loss: 15.3143 | Test Loss: 14.5206
Epoch 2/1000 | Train Loss: 14.5441 | Test Loss: 14.2272
Epoch 3/1000 | Train Loss: 14.3164 | Test Loss: 14.0423
Epoch 4/1000 | Train Loss: 14.1134 | Test Loss: 13.9909
Epoch 5/1000 | Train Loss: 14.0792 | Test Loss: 13.8330
Epoch 6/1000 | Train Loss: 13.7882 | Test Loss: 13.6578
Epoch 7/1000 | Train Loss: 13.6744 | Test Loss: 13.7155
Epoch 8/1000 | Train Loss: 13.5305 | Test Loss: 13.5384
Epoch 9/1000 | Train Loss: 13.5408 | Test Loss: 13.4030
Epoch 10/1000 | Train Loss: 13.5270 | Test Loss: 13.5321
Epoch 11/1000 | Train Loss: 13.4275 | Test Loss: 13.3616
Epoch 12/1000 | Train Loss: 13.3409 | Test Loss: 13.4895
Epoch 13/1000 | Train Loss: 13.2326 | Test Loss: 13.1596
Epoch 14/1000 | Train Loss: 13.1596 | Test Loss: 13.1739
Epoch 15/1000 | Train Loss: 13.0184 | Test Loss: 13.1595
Epoch 16/1000 | Train Loss: 12.8999 | Test Loss: 12.8433
Epoch 17/1000 | Train Loss: 12.8725 | Test Loss: 12.9558
Epoch 18/1000 | Train Loss: 12.8969 | Test Loss: 12.7351
Epoch 19/1000 | Train Loss: 12.7585 | Test Loss: 12.6521
Epoch 20/1000 | Train Loss: 12.6703 | Test Loss: 12.5746
Epoch 21/1000 | Train Loss: 12.5997 | Test Loss: 12.6308
Epoch 22/1000 | Train Loss: 12.4540 | Test Loss: 12.5919
Epoch 23/1000 | Train Loss: 12.5542 | Test Loss: 12.4800
Epoch 24/1000 | Train Loss: 12.4466 | Test Loss: 12.5020
Epoch 25/1000 | Train Loss: 12.3637 | Test Loss: 12.2970
Epoch 26/1000 | Train Loss: 12.3854 | Test Loss: 12.6369
Epoch 27/1000 | Train Loss: 12.4289 | Test Loss: 12.2421
Epoch 28/1000 | Train Loss: 12.3948 | Test Loss: 12.5283
Epoch 29/1000 | Train Loss: 12.2724 | Test Loss: 12.5595
Epoch 30/1000 | Train Loss: 12.3463 | Test Loss: 12.5177
Epoch 31/1000 | Train Loss: 12.2532 | Test Loss: 12.5354
Epoch 32/1000 | Train Loss: 12.3209 | Test Loss: 12.3344
Epoch 33/1000 | Train Loss: 12.3379 | Test Loss: 12.3521
Epoch 34/1000 | Train Loss: 12.3143 | Test Loss: 12.2789
Epoch 35/1000 | Train Loss: 12.3133 | Test Loss: 12.5243
Epoch 36/1000 | Train Loss: 12.2432 | Test Loss: 12.4296
Epoch 37/1000 | Train Loss: 12.2789 | Test Loss: 12.4316
Epoch 38/1000 | Train Loss: 12.2282 | Test Loss: 12.3293
Epoch 39/1000 | Train Loss: 12.2502 | Test Loss: 12.4584
Epoch 40/1000 | Train Loss: 12.3091 | Test Loss: 12.4258
Epoch 41/1000 | Train Loss: 12.1638 | Test Loss: 12.3796
Epoch 42/1000 | Train Loss: 12.2763 | Test Loss: 12.3078
Epoch 43/1000 | Train Loss: 12.2656 | Test Loss: 12.4695
Epoch 44/1000 | Train Loss: 12.3021 | Test Loss: 12.4233
Epoch 45/1000 | Train Loss: 12.1426 | Test Loss: 12.2583
Epoch 46/1000 | Train Loss: 12.2719 | Test Loss: 12.2178
Epoch 47/1000 | Train Loss: 12.2593 | Test Loss: 12.2725
Epoch 48/1000 | Train Loss: 12.2383 | Test Loss: 12.3514
Epoch 49/1000 | Train Loss: 12.1181 | Test Loss: 12.4072
Epoch 50/1000 | Train Loss: 12.0836 | Test Loss: 12.3616
Epoch 51/1000 | Train Loss: 12.2459 | Test Loss: 12.2764
Epoch 52/1000 | Train Loss: 12.1774 | Test Loss: 12.2319
Epoch 53/1000 | Train Loss: 12.1756 | Test Loss: 12.3363
Epoch 54/1000 | Train Loss: 12.0998 | Test Loss: 12.3563
Epoch 55/1000 | Train Loss: 12.0769 | Test Loss: 12.1846
Epoch 56/1000 | Train Loss: 12.1658 | Test Loss: 12.1814
Epoch 57/1000 | Train Loss: 12.2584 | Test Loss: 12.2048
Epoch 58/1000 | Train Loss: 12.1298 | Test Loss: 12.2736
Epoch 59/1000 | Train Loss: 12.1608 | Test Loss: 12.2898
Epoch 60/1000 | Train Loss: 12.0239 | Test Loss: 12.3782
Epoch 61/1000 | Train Loss: 12.1477 | Test Loss: 12.0479
Epoch 62/1000 | Train Loss: 12.1227 | Test Loss: 12.4227
Epoch 63/1000 | Train Loss: 12.1196 | Test Loss: 12.2449
Epoch 64/1000 | Train Loss: 12.0207 | Test Loss: 12.3184
Epoch 65/1000 | Train Loss: 12.0844 | Test Loss: 12.4369
Epoch 66/1000 | Train Loss: 12.1261 | Test Loss: 12.3580
Epoch 67/1000 | Train Loss: 12.0345 | Test Loss: 12.2135
Epoch 68/1000 | Train Loss: 12.0888 | Test Loss: 12.2756
Epoch 69/1000 | Train Loss: 12.1119 | Test Loss: 12.2135
Epoch 70/1000 | Train Loss: 11.9869 | Test Loss: 12.2734
Epoch 71/1000 | Train Loss: 12.0929 | Test Loss: 12.0073
Epoch 72/1000 | Train Loss: 12.1204 | Test Loss: 12.1544
Epoch 73/1000 | Train Loss: 11.9838 | Test Loss: 12.1363
Epoch 74/1000 | Train Loss: 12.1360 | Test Loss: 12.2834
Epoch 75/1000 | Train Loss: 11.9912 | Test Loss: 12.2601
Epoch 76/1000 | Train Loss: 12.0332 | Test Loss: 12.1388
Epoch 77/1000 | Train Loss: 12.0875 | Test Loss: 12.1467
Epoch 78/1000 | Train Loss: 11.9644 | Test Loss: 12.3583
Epoch 79/1000 | Train Loss: 12.0888 | Test Loss: 12.2773
Epoch 80/1000 | Train Loss: 12.0712 | Test Loss: 12.3570
Epoch 81/1000 | Train Loss: 12.0366 | Test Loss: 12.1873
Epoch 82/1000 | Train Loss: 12.0446 | Test Loss: 12.2192
Epoch 83/1000 | Train Loss: 11.9588 | Test Loss: 12.1476
Epoch 84/1000 | Train Loss: 12.0591 | Test Loss: 12.1768
Epoch 85/1000 | Train Loss: 11.9590 | Test Loss: 12.2443
Epoch 86/1000 | Train Loss: 11.9610 | Test Loss: 12.2786
Epoch 87/1000 | Train Loss: 11.9929 | Test Loss: 12.4902
Epoch 88/1000 | Train Loss: 11.9258 | Test Loss: 12.2453
Epoch 89/1000 | Train Loss: 11.9784 | Test Loss: 12.1693
Epoch 90/1000 | Train Loss: 12.0173 | Test Loss: 12.1154
Epoch 91/1000 | Train Loss: 11.9839 | Test Loss: 12.1429
Epoch 92/1000 | Train Loss: 11.8322 | Test Loss: 12.2297
Epoch 93/1000 | Train Loss: 11.9215 | Test Loss: 12.2400
Epoch 94/1000 | Train Loss: 12.0014 | Test Loss: 12.0487
Epoch 95/1000 | Train Loss: 12.0156 | Test Loss: 12.1343
Epoch 96/1000 | Train Loss: 11.9593 | Test Loss: 12.3031
Epoch 97/1000 | Train Loss: 11.9641 | Test Loss: 12.1161
Epoch 98/1000 | Train Loss: 12.0123 | Test Loss: 12.0265
Epoch 99/1000 | Train Loss: 11.9532 | Test Loss: 12.3249
Epoch 100/1000 | Train Loss: 11.8445 | Test Loss: 12.3431
Epoch 101/1000 | Train Loss: 12.0422 | Test Loss: 12.0286
Early stopping at epoch 101
Epoch 1/1000 | Train Loss: 13.8512 | Test Loss: 13.3184
Epoch 2/1000 | Train Loss: 13.1148 | Test Loss: 12.8622
Epoch 3/1000 | Train Loss: 12.6885 | Test Loss: 12.4904
Epoch 4/1000 | Train Loss: 12.3525 | Test Loss: 12.2722
Epoch 5/1000 | Train Loss: 12.1988 | Test Loss: 12.0703
Epoch 6/1000 | Train Loss: 12.0832 | Test Loss: 12.0370
Epoch 7/1000 | Train Loss: 12.0707 | Test Loss: 12.0223
Epoch 8/1000 | Train Loss: 12.0249 | Test Loss: 11.9396
Epoch 9/1000 | Train Loss: 12.0421 | Test Loss: 11.8700
Epoch 10/1000 | Train Loss: 11.9588 | Test Loss: 11.9392
Epoch 11/1000 | Train Loss: 11.9607 | Test Loss: 11.9116
Epoch 12/1000 | Train Loss: 11.9420 | Test Loss: 11.8527
Epoch 13/1000 | Train Loss: 11.9383 | Test Loss: 11.9190
Epoch 14/1000 | Train Loss: 11.9168 | Test Loss: 11.8795
Epoch 15/1000 | Train Loss: 11.8959 | Test Loss: 11.8604
Epoch 16/1000 | Train Loss: 11.9284 | Test Loss: 11.9318
Epoch 17/1000 | Train Loss: 11.8907 | Test Loss: 11.8615
Epoch 18/1000 | Train Loss: 11.8946 | Test Loss: 11.8777
Epoch 19/1000 | Train Loss: 11.9446 | Test Loss: 11.9216
Epoch 20/1000 | Train Loss: 11.8811 | Test Loss: 11.7708
Epoch 21/1000 | Train Loss: 11.8997 | Test Loss: 11.9013
Epoch 22/1000 | Train Loss: 11.9081 | Test Loss: 11.8816
Epoch 23/1000 | Train Loss: 11.8824 | Test Loss: 11.7657
Epoch 24/1000 | Train Loss: 11.8692 | Test Loss: 11.8659
Epoch 25/1000 | Train Loss: 11.8944 | Test Loss: 11.8023
Epoch 26/1000 | Train Loss: 11.8753 | Test Loss: 11.8261
Epoch 27/1000 | Train Loss: 11.8978 | Test Loss: 11.8922
Epoch 28/1000 | Train Loss: 11.8536 | Test Loss: 11.8470
Epoch 29/1000 | Train Loss: 11.8996 | Test Loss: 11.7920
Epoch 30/1000 | Train Loss: 11.8790 | Test Loss: 11.8041
Epoch 31/1000 | Train Loss: 11.8674 | Test Loss: 11.8703
Epoch 32/1000 | Train Loss: 11.8784 | Test Loss: 11.7797
Epoch 33/1000 | Train Loss: 11.8858 | Test Loss: 11.7773
Epoch 34/1000 | Train Loss: 11.8610 | Test Loss: 11.7589
Epoch 35/1000 | Train Loss: 11.8815 | Test Loss: 11.8217
Epoch 36/1000 | Train Loss: 11.8812 | Test Loss: 11.8198
Epoch 37/1000 | Train Loss: 11.8621 | Test Loss: 11.8226
Epoch 38/1000 | Train Loss: 11.8507 | Test Loss: 11.8185
Epoch 39/1000 | Train Loss: 11.8386 | Test Loss: 11.8318
Epoch 40/1000 | Train Loss: 11.8662 | Test Loss: 11.8078
Epoch 41/1000 | Train Loss: 11.9042 | Test Loss: 11.8097
Epoch 42/1000 | Train Loss: 11.8820 | Test Loss: 11.7710
Epoch 43/1000 | Train Loss: 11.8605 | Test Loss: 11.7969
Epoch 44/1000 | Train Loss: 11.8812 | Test Loss: 11.8130
Epoch 45/1000 | Train Loss: 11.8390 | Test Loss: 11.7780
Epoch 46/1000 | Train Loss: 11.8473 | Test Loss: 11.7617
Epoch 47/1000 | Train Loss: 11.8773 | Test Loss: 11.7419
Epoch 48/1000 | Train Loss: 11.8183 | Test Loss: 11.8625
Epoch 49/1000 | Train Loss: 11.8523 | Test Loss: 11.8417
Epoch 50/1000 | Train Loss: 11.8363 | Test Loss: 11.8311
Epoch 51/1000 | Train Loss: 11.8356 | Test Loss: 11.8375
Epoch 52/1000 | Train Loss: 11.8816 | Test Loss: 11.8135
Epoch 53/1000 | Train Loss: 11.8676 | Test Loss: 11.8157
Epoch 54/1000 | Train Loss: 11.8574 | Test Loss: 11.7945
Epoch 55/1000 | Train Loss: 11.8510 | Test Loss: 11.8232
Epoch 56/1000 | Train Loss: 11.8683 | Test Loss: 11.7542
Epoch 57/1000 | Train Loss: 11.8507 | Test Loss: 11.8625
Epoch 58/1000 | Train Loss: 11.8392 | Test Loss: 11.7640
Epoch 59/1000 | Train Loss: 11.8549 | Test Loss: 11.8906
Epoch 60/1000 | Train Loss: 11.8517 | Test Loss: 11.7967
Epoch 61/1000 | Train Loss: 11.8375 | Test Loss: 11.8566
Epoch 62/1000 | Train Loss: 11.8186 | Test Loss: 11.8473
Epoch 63/1000 | Train Loss: 11.8239 | Test Loss: 11.8126
Epoch 64/1000 | Train Loss: 11.8511 | Test Loss: 11.7618
Epoch 65/1000 | Train Loss: 11.8414 | Test Loss: 11.7448
Epoch 66/1000 | Train Loss: 11.8294 | Test Loss: 11.8077
Epoch 67/1000 | Train Loss: 11.8539 | Test Loss: 11.7916
Epoch 68/1000 | Train Loss: 11.8658 | Test Loss: 11.7878
Epoch 69/1000 | Train Loss: 11.8615 | Test Loss: 11.8276
Epoch 70/1000 | Train Loss: 11.8254 | Test Loss: 11.7756
Epoch 71/1000 | Train Loss: 11.8323 | Test Loss: 11.8589
Epoch 72/1000 | Train Loss: 11.7861 | Test Loss: 11.7638
Epoch 73/1000 | Train Loss: 11.8139 | Test Loss: 11.7869
Epoch 74/1000 | Train Loss: 11.8221 | Test Loss: 11.8388
Epoch 75/1000 | Train Loss: 11.8876 | Test Loss: 11.8327
Epoch 76/1000 | Train Loss: 11.8523 | Test Loss: 11.8629
Epoch 77/1000 | Train Loss: 11.8545 | Test Loss: 11.8116
Early stopping at epoch 77
fig, axs = plt.subplots(1, 3, figsize=(8, 6))
axs = axs.flatten()
losses_df_500[['train_loss', 'test_loss']].plot(ax=axs[0])
losses_df_1000[['train_loss', 'test_loss']].plot(ax=axs[1])
losses_df_10_000[['train_loss', 'test_loss']].plot(ax=axs[2])
axs[0].set_title("Training and Test Losses \n 500 observations");
axs[1].set_title("Training and Test Losses \n 1000 observations");
axs[2].set_title("Training and Test Losses \n 10_000 observations");


def bootstrap_residuals(vae_fit, X_test, sample_df, n_boot=1000):
    """Average correlation-residual matrix over `n_boot` VAE resamples.

    For each bootstrap draw, generates len(X_test) synthetic rows from the
    fitted VAE and computes corr(X_test) - corr(synthetic); returns the mean
    residual matrix as a DataFrame labelled by `sample_df`'s columns.
    """
    # Fix: the test-set correlation matrix is loop-invariant — compute it once
    # instead of rebuilding the DataFrame and recomputing corr() every iteration.
    test_corr = pd.DataFrame(X_test, columns=sample_df.columns).corr()
    resid_array = np.zeros((n_boot, len(sample_df.columns), len(sample_df.columns)))
    for i in range(n_boot):
        recon_data = vae_fit.generate(n_samples=len(X_test))
        reconstructed_df = pd.DataFrame(recon_data, columns=sample_df.columns)
        resid_array[i] = (test_corr - reconstructed_df.corr()).values
    # Fix: dropped the `recons` list the original accumulated but never used —
    # it held n_boot full DataFrames in memory for nothing.
    avg_resid = resid_array.mean(axis=0)
    return pd.DataFrame(avg_resid, columns=sample_df.columns, index=sample_df.columns)
# Bootstrap residuals for each fitted model against its held-out test set.
cols = sample_df.columns
bootstrapped_resids_500 = bootstrap_residuals(vae_fit_500, pd.DataFrame(test_500.dataset, columns=cols), sample_df)
bootstrapped_resids_1000 = bootstrap_residuals(vae_fit_1000, pd.DataFrame(test_1000.dataset, columns=cols), sample_df)
bootstrapped_resids_10_000 = bootstrap_residuals(vae_fit_10_000, pd.DataFrame(test_10_000.dataset, columns=cols), sample_df)

fig, axs = plt.subplots(3, 1, figsize=(10, 20))
axs = axs.flatten()
plot_heatmap(bootstrapped_resids_500, title="""Expected Correlation Residuals for 500 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[0], colorbar=True, vmin=-.25, vmax=.25)
plot_heatmap(bootstrapped_resids_1000, title="""Expected Correlation Residuals for 1000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[1], colorbar=True, vmin=-.25, vmax=.25)
plot_heatmap(bootstrapped_resids_10_000, title="""Expected Correlation Residuals for 10,000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[2], colorbar=True, vmin=-.25, vmax=.25)
# Missing Data
# Knock out roughly 5% of entries completely at random to simulate missingness.
sample_df_missing = sample_df.copy()
mask_remove = np.random.rand(*sample_df_missing.shape) < 0.05
sample_df_missing[mask_remove] = np.nan
sample_df_missing.head()
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.330438 | 0.326503 | 0.583599 | 0.298558 | 1.112858 | 0.371895 | 0.247663 | 0.704001 | 0.739605 | NaN | 1.546412 | 1.282066 |
| 1 | NaN | -0.598124 | -0.882404 | 0.034036 | -0.220583 | -0.443250 | 0.452083 | 0.976292 | 1.460018 | 0.208454 | -0.237027 | 0.124897 |
| 2 | -0.334853 | -0.171359 | -0.862147 | -0.696685 | 0.294389 | -0.671320 | 0.023049 | -0.117460 | 0.394511 | 0.769453 | 1.138158 | 0.216388 |
| 3 | -0.765717 | 0.554349 | 0.062522 | 0.181065 | 0.609657 | 0.595781 | -0.056995 | -0.635932 | -0.330862 | -0.424084 | -0.548190 | NaN |
| 4 | NaN | -0.369919 | -0.210114 | 0.185822 | -0.755927 | -0.490341 | -0.472250 | -0.419797 | 0.123084 | 1.157508 | -0.840009 | -0.144719 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class MissingDataDataset(Dataset):
    """Pairs each feature row with its observation mask.

    Both tensors must share the same (n_samples, n_features) shape;
    mask uses 1.0 for observed entries and 0.0 for missing ones.
    """

    def __init__(self, x, mask):
        self.x = x
        self.mask = mask

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.mask[idx]
def prep_data_vae_missing(sample_size=1000, batch_size=32, missing_frac=0.0):
    """Build train/test DataLoaders of (features, observation-mask) pairs.

    Parameters
    ----------
    sample_size : int
        Number of rows drawn from the multivariate normal via make_sample.
    batch_size : int
        DataLoader batch size.
    missing_frac : float
        Fraction of cells knocked out at random (MCAR). Defaults to 0.0,
        which preserves the original behavior.
        NOTE(review): the original never injected NaNs, so the masks were
        always all-ones despite the function's name — confirm intent.

    Returns
    -------
    (train_loader, test_loader) : DataLoaders over MissingDataDataset.
    """
    sample_df = make_sample(cov_matrix=cov_matrix, size=sample_size, columns=columns)
    X_train, X_test = train_test_split(sample_df.values, test_size=0.2, random_state=890)
    # Optionally remove a random fraction of entries so the masks are meaningful.
    if missing_frac > 0:
        X_train = X_train.copy()
        X_test = X_test.copy()
        X_train[np.random.rand(*X_train.shape) < missing_frac] = np.nan
        X_test[np.random.rand(*X_test.shape) < missing_frac] = np.nan
    # Mask: 1=observed, 0=missing
    mask_train = ~pd.DataFrame(X_train).isna()
    mask_test = ~pd.DataFrame(X_test).isna()
    # Tensors (keep NaNs for missing values; the model imputes them itself)
    x_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    mask_train_tensor = torch.tensor(mask_train.values, dtype=torch.float32)
    x_test_tensor = torch.tensor(X_test, dtype=torch.float32)
    mask_test_tensor = torch.tensor(mask_test.values, dtype=torch.float32)
    train_dataset = MissingDataDataset(x_train_tensor, mask_train_tensor)
    test_dataset = MissingDataDataset(x_test_tensor, mask_test_tensor)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader
def vae_loss_missing(recon_mu, recon_logvar, x_filled, mu, logvar, mask):
    """VAE loss that skips missing values in x for the reconstruction term.

    Returns (recon_loss, kl_loss), each averaged over the batch. The
    reconstruction term is a Gaussian negative log-likelihood computed only
    on observed features (mask == 1), normalized per sample by the number of
    observed features.
    """
    # BUG FIX: zero-fill masked positions before the NLL. If x_filled still
    # contains NaNs there, NaN * 0 == NaN and the masking below cannot
    # remove them, poisoning the whole loss.
    x_safe = torch.where(mask.bool(), x_filled, torch.zeros_like(x_filled))
    # Reconstruction loss (Gaussian NLL) only on observed values
    recon_var = torch.exp(recon_logvar)
    recon_nll = 0.5 * (torch.log(2 * torch.pi * recon_var)
                       + (x_safe - recon_mu) ** 2 / recon_var)
    # Apply mask and normalize by number of observed features per sample
    recon_nll = recon_nll * mask  # zero-out missing features
    obs_counts = mask.sum(dim=1).clamp(min=1)  # avoid division by 0
    recon_loss = (recon_nll.sum(dim=1) / obs_counts).mean()
    # Analytic KL( q(z|x) || N(0, I) ), summed over latent dimensions
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl_loss = kl_div.mean()
    return recon_loss, kl_loss
import torch.nn as nn
import torch.nn.functional as F
class NumericVAE_missing(nn.Module):
    """VAE for numeric data with missing values.

    NaN entries are imputed with learnable per-feature embeddings before
    encoding; the 0/1 observation mask is embedded separately and added to
    the feature embedding so the encoder knows which values were imputed.
    The decoder emits per-feature Gaussian parameters (mu, logvar).
    """

    def __init__(self, n_features, hidden_dim=64, latent_dim=8):
        super().__init__()
        self.n_features = n_features
        # ---------- Learnable imputation ----------
        # One learnable scalar per feature, substituted wherever x is NaN.
        self.missing_embeddings = nn.Parameter(torch.zeros(n_features))
        # ---------- Encoder ----------
        self.fc1_x = nn.Linear(n_features, hidden_dim)
        # Stronger mask encoder: 2-layer MLP over the observation mask.
        self.fc1_mask = nn.Sequential(
            nn.Linear(n_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        # Latent heads over the combined feature+mask embedding.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # ---------- Decoder ----------
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc_out_mu = nn.Linear(hidden_dim, n_features)
        self.fc_out_logvar = nn.Linear(hidden_dim, n_features)

    def encode(self, x, mask):
        """Return (mu, logvar) of q(z | x, mask); NaNs in x are imputed first."""
        x_filled = torch.where(
            torch.isnan(x),
            self.missing_embeddings.expand_as(x),
            x
        )
        # Encode features and mask separately, then combine additively.
        h_x = F.relu(self.fc1_x(x_filled))
        h_mask = self.fc1_mask(mask)
        h = h_x + h_mask
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Return per-feature Gaussian parameters (recon_mu, recon_logvar)."""
        h = F.relu(self.fc2(z))
        return self.fc_out_mu(h), self.fc_out_logvar(h)

    def forward(self, x, mask):
        mu, logvar = self.encode(x, mask)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def generate(self, n_samples=100):
        """Draw n_samples from the model: z ~ N(0, I), then x ~ p(x | z)."""
        # BUG FIX: the original called self.eval() and never restored the
        # previous mode, silently leaving the module in eval mode if called
        # mid-training. Snapshot and restore it.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                z = torch.randn(n_samples, self.fc_mu.out_features)
                recon_mu, recon_logvar = self.decode(z)
                return recon_mu + torch.exp(0.5 * recon_logvar) * torch.randn_like(recon_mu)
        finally:
            if was_training:
                self.train()


def vae_loss_missing(recon_mu, recon_logvar, x, mu, logvar, mask):
    """Masked Gaussian-NLL VAE loss; missing entries contribute nothing.

    Returns (recon_loss, kl_loss), each averaged over the batch.
    """
    # Fill missing values with 0 just for loss computation; NaN * 0 == NaN,
    # so the masking below alone would not be enough.
    x_filled = torch.where(mask.bool(), x, torch.zeros_like(x))
    recon_var = torch.exp(recon_logvar)
    recon_nll = 0.5 * (torch.log(2 * torch.pi * recon_var) +
                       (x_filled - recon_mu) ** 2 / recon_var)
    # Mask out missing values and normalize per sample by observed count.
    recon_nll = recon_nll * mask
    obs_counts = mask.sum(dim=1).clamp(min=1)
    recon_loss = (recon_nll.sum(dim=1) / obs_counts).mean()
    # Analytic KL( q(z|x) || N(0, I) ), summed over latent dimensions.
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl_loss = kl_div.mean()
    return recon_loss, kl_loss
def beta_annealing(epoch, max_beta=1.0, anneal_epochs=100):
    """Linear KL warm-up: ramp beta from 0 to max_beta over anneal_epochs."""
    return min(max_beta, max_beta * epoch / anneal_epochs)


# (De-fused from the notebook export: this statement was garbled onto the
# return line above.)
train_loader, test_loader = prep_data_vae_missing(10_000, batch_size=32)
# ---- Train the missing-data VAE with early stopping on the test loss ----
vae_missing = NumericVAE_missing(n_features=next(iter(train_loader))[0].shape[1])
optimizer = optim.Adam(vae_missing.parameters(), lr=1e-4)

best_loss = float('inf')
best_state = None
patience, wait = 30, 0
losses = []
n_epochs = 1000

for epoch in range(n_epochs):
    beta = beta_annealing(epoch, max_beta=1.0, anneal_epochs=10)

    # --- Training ---
    vae_missing.train()
    train_loss = 0.0
    for x_batch, mask_batch in train_loader:
        optimizer.zero_grad()
        (recon_mu, recon_logvar), mu, logvar = vae_missing(x_batch, mask_batch)
        recon_loss, kl_loss = vae_loss_missing(recon_mu, recon_logvar, x_batch, mu, logvar, mask_batch)
        loss = recon_loss + beta * kl_loss  # KL weight annealed in early epochs
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x_batch.size(0)
    avg_train_loss = train_loss / len(train_loader.dataset)

    # --- Validation (full, un-annealed ELBO) ---
    vae_missing.eval()
    test_loss = 0.0
    with torch.no_grad():
        for x_batch, mask_batch in test_loader:
            (recon_mu, recon_logvar), mu, logvar = vae_missing(x_batch, mask_batch)
            recon_loss, kl_loss = vae_loss_missing(recon_mu, recon_logvar, x_batch, mu, logvar, mask_batch)
            test_loss += (recon_loss + kl_loss).item() * x_batch.size(0)
    avg_test_loss = test_loss / len(test_loader.dataset)

    print(f"Epoch {epoch+1}/{n_epochs} | Train: {avg_train_loss:.4f} | Test: {avg_test_loss:.4f}")

    # --- Early stopping (on the summed test loss) ---
    if test_loss < best_loss - 1e-4:
        best_loss, wait = test_loss, 0
        # BUG FIX: state_dict() returns references to the live tensors, so the
        # original "snapshot" kept mutating as training continued. Clone it.
        best_state = {k: v.detach().clone() for k, v in vae_missing.state_dict().items()}
    else:
        wait += 1
        if wait >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            vae_missing.load_state_dict(best_state)  # restore best
            break
    losses.append([avg_train_loss, avg_test_loss, best_loss])

# (De-fused from the notebook export — was garbled onto the line above.)
bootstrapped_resids_500 = bootstrap_residuals(vae_missing, pd.DataFrame(test_loader.dataset.x, columns=sample_df.columns), sample_df)
plot_heatmap(bootstrapped_resids_500)/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:19: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_xticklabels(data.columns)
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:21: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_yticklabels(data.index)
# Bootstrap the correlation residuals: generate n_boot synthetic datasets and
# average (test-data correlations) - (reconstruction correlations).
recons = []
n_boot = 500
n_feats = len(sample_df_missing.columns)
resid_array = np.zeros((n_boot, n_feats, n_feats))
# Hoisted out of the loop: the test-set correlation matrix never changes.
test_corr = pd.DataFrame(test_loader.dataset.x, columns=sample_df_missing.columns).corr()
for i in range(n_boot):  # FIX: was a hard-coded range(500), now tied to n_boot
    recon_data = vae_missing.generate(n_samples=len(sample_df_missing))
    # Convert the tensor explicitly before building the DataFrame.
    reconstructed_df = pd.DataFrame(recon_data.numpy(), columns=sample_df_missing.columns)
    resid_array[i] = (test_corr - reconstructed_df.corr()).values
    recons.append(reconstructed_df)
avg_resid = resid_array.mean(axis=0)
bootstrapped_resids = pd.DataFrame(avg_resid, columns=sample_df_missing.columns, index=sample_df_missing.columns)
plot_heatmap(bootstrapped_resids, title="""Expected Residuals \n Under Bootstrapped Reconstructions""")/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:19: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_xticklabels(data.columns)
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:21: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_yticklabels(data.index)
# Draw one model reconstruction per row, then patch in values ONLY where the
# original cells were removed (mask_remove), keeping all observed data intact.
# FIX: convert the torch tensor to NumPy before boolean-mask assignment into a
# NumPy array — mixing a torch tensor with a NumPy bool mask is fragile.
recon_data = vae_missing.generate(n_samples=len(sample_df_missing)).numpy()
imputed_array = sample_df_missing.to_numpy().copy()
imputed_array[mask_remove] = recon_data[mask_remove]
imputed_df = pd.DataFrame(imputed_array, columns=sample_df_missing.columns)
imputed_df.head()| JW1 | JW2 | JW3 | UF1 | UF2 | FOR | DA1 | DA2 | DA3 | EBA | ST | MI | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.330438 | 0.326503 | 0.583599 | 0.298558 | 1.112858 | 0.371895 | 0.247663 | 0.704001 | 0.739605 | -0.210920 | 1.546412 | 1.282066 |
| 1 | -0.592279 | -0.598124 | -0.882404 | 0.034036 | -0.220583 | -0.443250 | 0.452083 | 0.976292 | 1.460018 | 0.208454 | -0.237027 | 0.124897 |
| 2 | -0.334853 | -0.171359 | -0.862147 | -0.696685 | 0.294389 | -0.671320 | 0.023049 | -0.117460 | 0.394511 | 0.769453 | 1.138158 | 0.216388 |
| 3 | -0.765717 | 0.554349 | 0.062522 | 0.181065 | 0.609657 | 0.595781 | -0.056995 | -0.635932 | -0.330862 | -0.424084 | -0.548190 | 2.183827 |
| 4 | 0.946249 | -0.369919 | -0.210114 | 0.185822 | -0.755927 | -0.490341 | -0.472250 | -0.419797 | 0.123084 | 1.157508 | -0.840009 | -0.144719 |
plot_heatmap(sample_df.corr() - imputed_df.corr())/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:19: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_xticklabels(data.columns)
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:21: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_yticklabels(data.index)
# Side-by-side heatmaps: raw data with missing cells vs the imputed version.
fig, axs = plt.subplots(1, 2, figsize=(9, 30))
axs = axs.flatten()
# Missing cells are filled with 99 so they saturate the top of the color scale.
plot_heatmap(sample_df_missing.head(50).fillna(99), vmin=-0, vmax=99, ax=axs[0], colorbar=False)
axs[0].set_title("Missing Data", fontsize=12)  # FIX: typo "Missng" -> "Missing"
plot_heatmap(imputed_df.head(50), vmin=-2, vmax=2, ax=axs[1], colorbar=False)
axs[1].set_title("Imputed Data", fontsize=12)
plt.tight_layout()
ax.set_xticklabels(data.columns)
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:21: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_yticklabels(data.index)
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:19: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_xticklabels(data.columns)
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:21: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_yticklabels(data.index)
Bayesian Inference
def make_pymc_model(sample_df):
    """Fit a Bayesian multivariate normal to sample_df and return (idata, model).

    Places an LKJ prior on the Cholesky factor of the covariance matrix and
    standard-normal priors on the means. NaNs in the observed data are
    handled by PyMC's automatic imputation of missing observations.

    Parameters
    ----------
    sample_df : pd.DataFrame
        Observations, one feature per column (may contain NaNs).

    Returns
    -------
    (idata, model) : InferenceData with prior, posterior and posterior
        predictive groups, plus the pm.Model itself.
    """
    # Generalized: infer the dimension from the data (was hard-coded n=12).
    n_features = sample_df.shape[1]
    coords = {'features': sample_df.columns,
              'features1': sample_df.columns,  # second copy for the cov matrix dims
              'obs': range(len(sample_df))}
    with pm.Model(coords=coords) as model:
        # Priors
        mus = pm.Normal("mus", 0, 1, dims='features')
        chol, _, _ = pm.LKJCholeskyCov("chol", n=n_features, eta=1.0, sd_dist=pm.HalfNormal.dist(1))
        cov = pm.Deterministic('cov', pm.math.dot(chol, chol.T), dims=('features', 'features1'))
        pm.MvNormal('likelihood', mus, cov=cov, observed=sample_df.values, dims=('obs', 'features'))
        idata = pm.sample_prior_predictive()
        idata.extend(pm.sample(random_seed=120))
        pm.sample_posterior_predictive(idata, extend_inferencedata=True)
    return idata, model
# (De-fused from the notebook export: three statements were garbled onto one
# line together with cell output.)
idata, model = make_pymc_model(sample_df)
pm.model_to_graphviz(model)

import arviz as az

# Posterior-mean correlation matrix vs the empirical sample correlations.
expected_corr = pd.DataFrame(
    az.summary(idata, var_names=['chol_corr'])['mean'].values.reshape((12, 12)),
    columns=sample_df.columns, index=sample_df.columns)
resids = sample_df.corr() - expected_corr
plot_heatmap(resids)
(between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
/Users/nathanielforde/mambaforge/envs/pytorch-env/lib/python3.10/site-packages/arviz/stats/diagnostics.py:991: RuntimeWarning: invalid value encountered in scalar divide
varsd = varvar / evar / 4
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:19: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_xticklabels(data.columns)
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:21: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_yticklabels(data.index)
Missing Data
# Same diagnostic, now fit on the data with ~5% missing cells (PyMC imputes
# the NaNs automatically). De-fused from the garbled notebook export.
idata_missing, model_missing = make_pymc_model(sample_df_missing)
pm.model_to_graphviz(model_missing)

expected_corr = pd.DataFrame(
    az.summary(idata_missing, var_names=['chol_corr'])['mean'].values.reshape((12, 12)),
    columns=sample_df.columns, index=sample_df.columns)
resids = sample_df.corr() - expected_corr
plot_heatmap(resids)
(between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
/Users/nathanielforde/mambaforge/envs/pytorch-env/lib/python3.10/site-packages/arviz/stats/diagnostics.py:991: RuntimeWarning: invalid value encountered in scalar divide
varsd = varvar / evar / 4
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:19: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_xticklabels(data.columns)
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_97259/4216500546.py:21: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
ax.set_yticklabels(data.index)
Citation
BibTeX citation:
@online{forde2025,
author = {Forde, Nathaniel},
title = {Amortized {Bayesian} {Inference} with {PyTorch}},
date = {2025-07-25},
langid = {en},
abstract = {The cost of generating new sample data can be prohibitive.
There is a secondary but different cost which attaches to the
“construction” of novel data. Principal Components Analysis can be
seen as a technique to optimally reconstruct a complex multivariate
data set from a lower level compressed dimensional space.
Variational auto-encoders allow us to achieve yet more flexible
reconstruction results in non-linear cases. Drawing a new sample
from the posterior predictive distribution of Bayesian models
similarly supplies us with insight in the variability of realised
data. Both methods assume a latent model of the data generating
process that aims to leverage a compressed representation of the
data. These are different heuristics with different consequences for
how we understand the variability in the world. Amortized Bayesian
inference seeks to unite the two heuristics.}
}
For attribution, please cite this work as:
Forde, Nathaniel. 2025. “Amortized Bayesian Inference with
PyTorch.” July 25, 2025.